import os

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import scipy.stats as sps

from utils import create_one_dim_inf_dim_market, solve_one_dim_inf_dim_market, square_integrate
# from utils import eval

# Plot styling.
# NOTE(review): `font` is not referenced anywhere below -- confirm it can be dropped.
font = dict(size=24)
sns.set_theme()

# Build the market (fixed seed sd=2022, normalized) with n buyers and solve it,
# also asking the solver for the allocation (buyer -> segment index, breakpoints).
n = 50
B, c, d = create_one_dim_inf_dim_market(n, sd=2022, normalize=True)
u_eq_true, place_of_buyer, bpts = solve_one_dim_inf_dim_market(
    B, c, d, return_allocation=True
)

# Reference NSW value at the solved equilibrium: sum_i B_i * log(u_i).
nsw_true = np.sum(B * np.log(u_eq_true))

# Compute the variance by integrating p_eq. For each buyer k, add
# beta_k^2 * square_integrate(c_k, d_k, l, r), where [l, r] = [bpts[j], bpts[j + 1]]
# is the segment the solver assigned to that buyer (j = place_of_buyer[k]).
beta_eq_true = B / u_eq_true
variance_true = sum(
    beta_eq_true[k] ** 2
    * square_integrate(c[k], d[k], bpts[place_of_buyer[k]], bpts[place_of_buyer[k] + 1])
    for k in range(n)
)
# NOTE(review): subtracting 1 presumably removes (integral of p_eq)^2, assuming
# p_eq integrates to 1 under normalize=True -- confirm (utils.eval could check it).
variance_true -= 1

# Load the NSW trajectories saved for seeds sd = 1..num_seeds. Each file holds
# `nsw_list` (presumably one value per entry of `t_list`, since the plot below
# pairs t_list[-1] with loaded[:, -1]). `t_list` is read from the first file only,
# which assumes every run shares the same t grid.
# The file name is built from `n` rather than a hard-coded 50, so it stays in sync
# with the market size used for the reference solution above.
num_seeds = 50
loaded = []
for sd in range(1, num_seeds + 1):
    with np.load(f'results/nsw_list_clt_n_{n}_one_dim_sd_{sd}.npz') as data:
        if sd == 1:
            t_list = data['t_list']
        loaded.append(data['nsw_list'])
# Shape (num_seeds, len(nsw_list)) when all runs have equal-length trajectories.
loaded = np.array(loaded)

# Plot the NSW values at the largest t value: a histogram of
# sqrt(t) * (NSW - NSW*) across seeds against the normal density N(0, variance_true).
t = t_list[-1]
nsw_values = loaded[:, -1]
scaled_nsw_values = (t**0.5) * (nsw_values - nsw_true)
plt.hist(scaled_nsw_values, density=True, label=r'$\sqrt{t}({\rm NSW}^{\gamma} - {\rm NSW}^*)$')

# Size the x-grid from the data and the predicted spread, not from a fixed
# [-0.3, 0.3] window, so neither the histogram range nor the density tails
# are cut off. Putting the data term first also keeps the choice finite
# if sigma comes out NaN.
sigma = variance_true**0.5
half_width = max(np.abs(scaled_nsw_values).max(), 4 * sigma)
normal_dist_x = np.linspace(-half_width, half_width, 1000)
normal_dist_y = sps.norm.pdf(normal_dist_x, loc=0, scale=sigma)
# `{\rm N}` rather than `{\rm}N`: the empty group left N in math italics.
plt.plot(normal_dist_x, normal_dist_y, label=r'${\rm N}(0, \sigma_N^2)$')

plt.legend()
# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/nsw_one_dim_clt_verify.pdf', bbox_inches='tight')